
Add support for DistributedEmbedding for Ironwood and expose has_sparsecores. (#183)

Merged

hertschuh merged 1 commit into keras-team:main from hertschuh:tpuv7 on Mar 2, 2026

Conversation


@hertschuh hertschuh commented Feb 26, 2026

  • Use num_sparsecores_per_device from jax-tpu-embedding instead of having a duplicated hardcoded list of supported TPUs.
  • Added public class method DistributedEmbedding.has_sparsecores.
  • Added warning when running JAX on TPU with jax-tpu-embedding not installed.
  • Made error messages more specific and consistent with Keras errors.
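As a rough illustration of the first three points, the detection flow could look like the sketch below. This is not the actual keras-rs implementation: `get_num_sparsecores` is a stand-in for jax-tpu-embedding's `num_sparsecores_per_device`, injected as a parameter so the sketch runs without TPU hardware or jax-tpu-embedding installed.

```python
import warnings

def has_sparsecores(get_num_sparsecores, on_tpu=True):
    """Sketch: report whether the current device exposes SparseCore units.

    `get_num_sparsecores` stands in for jax-tpu-embedding's
    `num_sparsecores_per_device` (a hypothetical injection point for
    illustration only).
    """
    if not on_tpu:
        return False
    try:
        return get_num_sparsecores() > 0
    except ImportError:
        # jax-tpu-embedding is missing: warn instead of failing outright,
        # mirroring the warning behavior described in this PR.
        warnings.warn(
            "Running JAX on TPU but `jax-tpu-embedding` is not installed; "
            "SparseCore placement will be unavailable."
        )
        return False

def missing_library():
    # Simulates jax-tpu-embedding not being installed.
    raise ImportError("jax-tpu-embedding not installed")

print(has_sparsecores(lambda: 4))        # True
print(has_sparsecores(missing_library))  # False (plus a UserWarning)
```

Asking the library for the SparseCore count, rather than keeping a hardcoded device list, means new TPU generations are picked up automatically once jax-tpu-embedding knows about them.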

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @hertschuh, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the DistributedEmbedding layer by adding explicit support for the TPU7x (Ironwood) architecture. This ensures that the system correctly identifies and leverages the sparse core capabilities of these new TPU devices, enabling efficient distributed embedding operations on the latest hardware.

Highlights

  • TPU7x (Ironwood) Support: Extended DistributedEmbedding to recognize and utilize TPU7x (Ironwood) devices, identified as TPUv7x, for sparse core operations.
  • Sparse Core Detection Logic Update: Updated the internal _has_sparsecore method in base_distributed_embedding.py and the has_sparsecores utility function in test_utils.py to include TPU7x in the list of supported device kinds.


Changelog
  • keras_rs/src/layers/embedding/base_distributed_embedding.py
    • Added 'TPU7x' to the list of device kinds recognized as having sparse core capabilities within the _has_sparsecore method.
  • keras_rs/src/layers/embedding/jax/test_utils.py
    • Included 'TPU7x' in the device_kind check within the has_sparsecores utility function to align with new hardware support.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment:
Code Review

This pull request adds support for Ironwood TPUs (identified as TPU7x) to DistributedEmbedding by updating the list of devices that have SparseCore. The change is straightforward, but it introduces code duplication across two files. My feedback suggests creating a single constant for the list of TPU device kinds to improve maintainability. I've also pointed out a potential inconsistency in the new device kind string "TPU7x" and asked for verification.

Comment on lines 24 to 27:

```diff
 device_kind = jax.devices()[0].device_kind
-if device_kind in ["TPU v5", "TPU v6 lite"]:
+if device_kind in ["TPU v5", "TPU v6 lite", "TPU7x"]:
     return True
 return False
```
gemini-code-assist bot commented (severity: medium):

This function can be simplified to a single return statement for better readability.

```python
device_kind = jax.devices()[0].device_kind
return device_kind in ["TPU v5", "TPU v6 lite", "TPU7x"]
```

@hertschuh hertschuh force-pushed the tpuv7 branch 2 times, most recently from a7147e0 to 83f0017 Compare February 26, 2026 00:42
@hertschuh hertschuh changed the title Add support for DistributedEmbedding for Ironwood. Add support for DistributedEmbedding for Ironwood and expose has_sparsecores. Feb 26, 2026
@hertschuh hertschuh requested a review from cantonios February 26, 2026 00:44
```diff
 if len(tpu_devices) > 0:
     device_kind = tpu_devices[0].device_kind
-    if device_kind in ["TPU v5", "TPU v6 lite"]:
+    if device_kind in ["TPU v5", "TPU v5p", "TPU v6 lite", "TPU7x"]:
```
Collaborator commented:

Here there might be a discrepancy between keras RS claiming the TPU has sparsecores vs the number of available sparsecores reported by jax_tpu_embedding. It might be worth having the try-except block here.

What we don't want is something like TPU7x being reported as having sparsecores, but then using an incorrect # SC per device.


Collaborator replied:

Right, but it might get out-of-sync. Wouldn't it just be better to check the number of SCs per device, and if that returns 0 or throws an exception, then return False here?

hertschuh (author) replied:

You wrote that code 😛

And I think the reason was that you wanted this to work even if jax-tpu-embedding is not installed. Then, if the users specified placement="auto" or placement="sparsecore" but jax-tpu-embedding is not installed, it gives you a different error message.
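The trade-off described here can be sketched as follows. This is an assumed shape, not the actual keras-rs code: `resolve_placement`, `library_installed`, and `device_has_sc` are hypothetical names used only to illustrate how `placement="auto"` can degrade gracefully while an explicit `placement="sparsecore"` surfaces a targeted error.

```python
def resolve_placement(placement, library_installed, device_has_sc):
    """Sketch of placement resolution with differentiated error messages."""
    if placement == "auto":
        # "auto" degrades silently to the default device path when
        # SparseCore is unavailable for any reason.
        if library_installed and device_has_sc:
            return "sparsecore"
        return "default_device"
    if placement == "sparsecore":
        # An explicit request gets a specific, actionable error.
        if not library_installed:
            raise ImportError(
                "`placement='sparsecore'` requires the `jax-tpu-embedding` "
                "package, which is not installed."
            )
        if not device_has_sc:
            raise ValueError(
                "`placement='sparsecore'` was requested, but the current "
                "device has no SparseCore units."
            )
        return "sparsecore"
    return placement
```

The point of the two branches is exactly the distinction made above: a missing optional dependency should not break users who never asked for SparseCore, but should produce a clear message for users who did.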

@hertschuh hertschuh force-pushed the tpuv7 branch 5 times, most recently from a848522 to 608dc2c Compare February 27, 2026 00:33
@hertschuh hertschuh requested a review from cantonios February 27, 2026 01:38
```python
        return 1
    except ValueError:
        # Default to one for non-sparsecore tests.
        return 1
```
Collaborator commented:

Would it be better to raise the error, and just not test the non-sparsecore tests if there are no sparsecores?

hertschuh (author) replied:

These tests are also run on CPU. Do you know why?

I can just skip them if not on SparseCore.

Collaborator replied:

Oh, probably because we didn't have the CI configured to actually test on TPU, but wanted to check that the basic functionality works?

I leave it to you to decide how to handle this.

hertschuh (author) replied:

They are now skipped when there are no sparsecores.
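A minimal sketch of that skipping pattern with `unittest` is shown below; `has_sparsecores` is stubbed to `False` here to emulate a CPU-only run, whereas the real suite would query the device as in this PR.

```python
import unittest

def has_sparsecores():
    # Stub: pretend no SparseCore hardware is present (e.g. a CPU run).
    return False

class DistributedEmbeddingTest(unittest.TestCase):
    def setUp(self):
        # Skip every test in the class when SparseCore is unavailable.
        if not has_sparsecores():
            self.skipTest("No SparseCore on this device.")

    def test_sparsecore_lookup(self):
        # Would only execute (and exercise real lookups) on SparseCore.
        self.fail("should never run without SparseCore hardware")

suite = unittest.TestLoader().loadTestsFromTestCase(DistributedEmbeddingTest)
outcome = unittest.TextTestRunner(verbosity=0).run(suite)
print(len(outcome.skipped))  # 1
```

Skipping in `setUp` keeps the hardware check in one place, so every test method in the class is skipped consistently instead of each repeating the guard.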

…parsecores`.

- Use `num_sparsecores_per_device` from `jax-tpu-embedding` instead of having a duplicated hardcoded list of supported TPUs.
- Added public class method `DistributedEmbedding.has_sparsecores`.
- Added warning when running JAX on TPU with `jax-tpu-embedding` not installed.
- Made error messages more specific and consistent with Keras errors.
@hertschuh hertschuh merged commit 9bc0f45 into keras-team:main Mar 2, 2026
11 checks passed
@hertschuh hertschuh deleted the tpuv7 branch March 2, 2026 19:56